{
 "cells": [
  {
   "cell_type": "raw",
   "source": [
    "---\n",
    "sidebar_label: YUAN2\n",
    "---"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% raw\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# YUAN2.0\n",
    "\n",
    "This notebook shows how to use the [YUAN2 API](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/docs/inference_server.md) in LangChain through `langchain_community.chat_models.ChatYuan2`.\n",
    "\n",
    "[*Yuan2.0*](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md) is a new-generation fundamental large language model developed by IEIT System. We have published all three models: Yuan 2.0-102B, Yuan 2.0-51B, and Yuan 2.0-2B. We also provide scripts for pre-training, fine-tuning, and inference services for other developers. Yuan2.0 builds on Yuan1.0 and uses a wider range of high-quality pre-training data and instruction fine-tuning datasets to improve the model's understanding of semantics, mathematics, reasoning, code, and knowledge."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Getting started\n",
    "### Installation\n",
    "Yuan2.0 provides an OpenAI-compatible API, and ChatYuan2 is integrated into LangChain as a chat model that uses the OpenAI client.\n",
    "Make sure the `openai` package is installed in your Python environment by running the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%pip install --upgrade --quiet openai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Importing the Required Modules\n",
    "After installation, import the necessary modules into your Python script:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_community.chat_models import ChatYuan2\n",
    "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Setting Up Your API Server\n",
    "Set up your OpenAI-compatible API server by following the [Yuan2.0 OpenAI API server guide](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md).\n",
    "If you deployed the API server locally, you can set `yuan2_api_key=\"EMPTY\"` or any other value.\n",
    "Just make sure that `yuan2_api_base` points to your server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "yuan2_api_key = \"your_api_key\"\n",
    "yuan2_api_base = \"http://127.0.0.1:8001/v1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Initialize the ChatYuan2 Model\n",
    "Here's how to initialize the chat model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "chat = ChatYuan2(\n",
    "    yuan2_api_key=yuan2_api_key,\n",
    "    yuan2_api_base=yuan2_api_base,\n",
    "    temperature=1.0,\n",
    "    model_name=\"yuan2\",\n",
    "    max_retries=3,\n",
    "    streaming=False,\n",
    ")"
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Basic Usage\n",
    "Invoke the model with system and human messages like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "messages = [\n",
    "    SystemMessage(content=\"你是一个人工智能助手。\"),\n",
    "    HumanMessage(content=\"你好，你是谁？\"),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "print(chat.invoke(messages))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Basic Usage with Streaming\n",
    "To receive the response incrementally as it is generated, enable streaming:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
    "\n",
    "chat = ChatYuan2(\n",
    "    yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
    "    temperature=1.0,\n",
    "    model_name=\"yuan2\",\n",
    "    max_retries=3,\n",
    "    streaming=True,\n",
    "    callbacks=[StreamingStdOutCallbackHandler()],\n",
    ")\n",
    "messages = [\n",
    "    SystemMessage(content=\"你是个旅游小助手。\"),\n",
    "    HumanMessage(content=\"给我介绍一下北京有哪些好玩的。\"),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "chat.invoke(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Advanced Features\n",
    "### Usage with async calls\n",
    "\n",
    "Invoke the model with non-blocking calls, like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "async def basic_agenerate():\n",
    "    chat = ChatYuan2(\n",
    "        yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
    "        temperature=1.0,\n",
    "        model_name=\"yuan2\",\n",
    "        max_retries=3,\n",
    "    )\n",
    "    messages = [\n",
    "        [\n",
    "            SystemMessage(content=\"你是个旅游小助手。\"),\n",
    "            HumanMessage(content=\"给我介绍一下北京有哪些好玩的。\"),\n",
    "        ]\n",
    "    ]\n",
    "\n",
    "    result = await chat.agenerate(messages)\n",
    "    print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
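   },
   "outputs": [],
   "source": [
    "import asyncio\n",
    "\n",
    "# A notebook already runs an event loop, so await the coroutine directly.\n",
    "# In a standalone script, call asyncio.run(basic_agenerate()) instead.\n",
    "await basic_agenerate()"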
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Usage with prompt template\n",
    "\n",
    "Combine a chat prompt template with the model and invoke the resulting chain with a non-blocking call, like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "async def ainvoke_with_prompt_template():\n",
    "    from langchain_core.prompts import ChatPromptTemplate\n",
    "\n",
    "    chat = ChatYuan2(\n",
    "        yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
    "        temperature=1.0,\n",
    "        model_name=\"yuan2\",\n",
    "        max_retries=3,\n",
    "    )\n",
    "    prompt = ChatPromptTemplate.from_messages(\n",
    "        [\n",
    "            (\"system\", \"你是一个诗人，擅长写诗。\"),\n",
    "            (\"human\", \"给我写首诗，主题是{theme}。\"),\n",
    "        ]\n",
    "    )\n",
    "    chain = prompt | chat\n",
    "    result = await chain.ainvoke({\"theme\": \"明月\"})\n",
    "    print(f\"type(result): {type(result)}; {result}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "await ainvoke_with_prompt_template()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Usage with async calls in streaming\n",
    "For non-blocking calls with streaming output, use the astream method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "async def basic_astream():\n",
    "    chat = ChatYuan2(\n",
    "        yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
    "        temperature=1.0,\n",
    "        model_name=\"yuan2\",\n",
    "        max_retries=3,\n",
    "    )\n",
    "    messages = [\n",
    "        SystemMessage(content=\"你是个旅游小助手。\"),\n",
    "        HumanMessage(content=\"给我介绍一下北京有哪些好玩的。\"),\n",
    "    ]\n",
    "    result = chat.astream(messages)\n",
    "    async for chunk in result:\n",
    "        print(chunk.content, end=\"\", flush=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "await basic_astream()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}